{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<center>\n",
    "    <p style=\"text-align:center\">\n",
    "        <img alt=\"phoenix logo\" src=\"https://storage.googleapis.com/arize-phoenix-assets/assets/phoenix-logo-light.svg\" width=\"200\"/>\n",
    "        <br>\n",
    "        <a href=\"https://arize.com/docs/phoenix/\">Docs</a>\n",
    "        |\n",
    "        <a href=\"https://github.com/Arize-ai/phoenix\">GitHub</a>\n",
    "        |\n",
    "        <a href=\"https://arize-ai.slack.com/join/shared_invite/zt-2w57bhem8-hq24MB6u7yE_ZF_ilOYSBw#/shared-invite/email\">Community</a>\n",
    "    </p>\n",
    "</center>\n",
    "\n",
    "# Google GenAI SDK - Building an Orchestrator Agent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the running kernel's environment.\n",
    "%pip install -q google-genai arize-phoenix-otel openinference-instrumentation-google-genai"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Connect to Arize Phoenix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "from google import genai\n",
    "from google.genai import types\n",
    "\n",
    "from phoenix.otel import register\n",
    "\n",
    "# Prompt for Phoenix settings only when they are not already set, so this cell\n",
    "# is safe to re-run and works unattended in pre-configured environments.\n",
    "if \"PHOENIX_API_KEY\" not in os.environ:\n",
    "    os.environ[\"PHOENIX_API_KEY\"] = getpass(\"🔑 Enter your Phoenix API key: \")\n",
    "\n",
    "if \"PHOENIX_COLLECTOR_ENDPOINT\" not in os.environ:\n",
    "    os.environ[\"PHOENIX_COLLECTOR_ENDPOINT\"] = getpass(\"🔑 Enter your Phoenix Collector Endpoint: \")\n",
    "\n",
    "# register() sets up OpenTelemetry tracing for the named Phoenix project;\n",
    "# auto_instrument=True enables the installed openinference instrumentors.\n",
    "tracer_provider = register(auto_instrument=True, project_name=\"google-genai-orchestrator-agent\")\n",
    "tracer = tracer_provider.get_tracer(__name__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Authenticate with Google Vertex AI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Authenticate with Google Cloud so the Vertex AI client created below can\n",
    "# obtain credentials (skip if this environment is already authenticated).\n",
    "!gcloud auth login"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a client using the Vertex AI API; you could also use the Google GenAI API instead here.\n",
    "# Read the GCP project ID from the environment (prompting once if unset) rather\n",
    "# than hardcoding a placeholder that must be edited by hand.\n",
    "if \"GCP_PROJECT_ID\" not in os.environ:\n",
    "    os.environ[\"GCP_PROJECT_ID\"] = getpass(\"🔑 Enter your GCP Project ID: \")\n",
    "\n",
    "client = genai.Client(vertexai=True, project=os.environ[\"GCP_PROJECT_ID\"], location=\"us-central1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Orchestration Agent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, define the sub-agents — implemented here as plain tool functions — that the orchestrator can choose between."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define models for different specialized agents\n",
    "FLASH_MODEL = \"gemini-2.0-flash-001\"\n",
    "\n",
    "\n",
    "def _call_specialist(intro, query, context, outro):\n",
    "    \"\"\"Shared helper: build a specialist prompt and make a single model call.\n",
    "\n",
    "    Args:\n",
    "        intro: Role-setting first line of the prompt.\n",
    "        query: The user query to address.\n",
    "        context: Accumulated context from prior agent calls.\n",
    "        outro: Closing instruction describing the expected output.\n",
    "    Returns:\n",
    "        The model's response text with surrounding whitespace stripped.\n",
    "    \"\"\"\n",
    "    prompt = f\"\"\"{intro}\n",
    "    Context: {context}\n",
    "    Query: {query}\n",
    "    {outro}\"\"\"\n",
    "\n",
    "    response = client.models.generate_content(\n",
    "        model=FLASH_MODEL,\n",
    "        contents=prompt,\n",
    "    )\n",
    "    return response.text.strip()\n",
    "\n",
    "\n",
    "@tracer.chain()\n",
    "def call_user_proxy_agent(query, context=\"\"):\n",
    "    \"\"\"User proxy agent that acts as the user and gives feedback.\"\"\"\n",
    "    return _call_specialist(\n",
    "        \"You are a user proxy assistant. Provide feedback as if you were the user on:\",\n",
    "        query,\n",
    "        context,\n",
    "        \"Give honest, constructive feedback from a user's perspective.\",\n",
    "    )\n",
    "\n",
    "\n",
    "@tracer.chain()\n",
    "def call_flight_planning_agent(query, context=\"\"):\n",
    "    \"\"\"Flight planning agent that helps find and recommend flights.\"\"\"\n",
    "    return _call_specialist(\n",
    "        \"You are a flight planning assistant. Help plan flights for:\",\n",
    "        query,\n",
    "        context,\n",
    "        \"Provide detailed flight options with considerations for price, timing, and convenience.\",\n",
    "    )\n",
    "\n",
    "\n",
    "@tracer.chain()\n",
    "def call_hotel_recommendation_agent(query, context=\"\"):\n",
    "    \"\"\"Hotel recommendation agent that suggests accommodations.\"\"\"\n",
    "    return _call_specialist(\n",
    "        \"You are a hotel recommendation assistant. Suggest accommodations for:\",\n",
    "        query,\n",
    "        context,\n",
    "        \"Provide suitable hotel options with details on amenities, location, and price ranges.\",\n",
    "    )\n",
    "\n",
    "\n",
    "@tracer.chain()\n",
    "def call_travel_attraction_agent(query, context=\"\"):\n",
    "    \"\"\"Travel attraction recommendation agent that suggests places to visit.\"\"\"\n",
    "    return _call_specialist(\n",
    "        \"You are a travel attraction recommendation assistant. Suggest attractions for:\",\n",
    "        query,\n",
    "        context,\n",
    "        \"Provide interesting places to visit with descriptions, highlights, and practical information.\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tracer.chain()\n",
    "def determine_next_step(user_query, context, cycle, max_cycles):\n",
    "    \"\"\"\n",
    "    Determines the next agent to call based on the current context and user query.\n",
    "    Args:\n",
    "        user_query: The initial user query\n",
    "        context: Current accumulated context\n",
    "        cycle: Current cycle number\n",
    "        max_cycles: Maximum number of agent calls\n",
    "    Returns:\n",
    "        The function name to call next\n",
    "    \"\"\"\n",
    "    orchestration_prompt = f\"\"\"You are an orchestration agent. Decide the next step to take:\n",
    "    User query: {user_query}\n",
    "    Current context: {context}\n",
    "    Current cycle: {cycle}/{max_cycles}\n",
    "    Choose one of the available tools to help address the user query, or decide to return a final answer.\n",
    "    \"\"\"\n",
    "\n",
    "    # Define orchestrator tools. Every callable tool listed here must have a\n",
    "    # matching branch in execute_agent_call, otherwise a cycle is wasted on an\n",
    "    # \"Unknown\" agent that contributes nothing. (A previously advertised\n",
    "    # \"call_planning_agent\" tool had no implementation and has been removed.)\n",
    "    orchestrator_tools = {\n",
    "        \"function_declarations\": [\n",
    "            {\n",
    "                \"name\": \"call_flight_planning_agent\",\n",
    "                \"description\": \"Call flight planning agent to help find and recommend flights\",\n",
    "            },\n",
    "            {\n",
    "                \"name\": \"call_hotel_recommendation_agent\",\n",
    "                \"description\": \"Call hotel recommendation agent to suggest accommodations\",\n",
    "            },\n",
    "            {\n",
    "                \"name\": \"call_travel_attraction_agent\",\n",
    "                \"description\": \"Call travel attraction agent to suggest interesting places to visit with descriptions\",\n",
    "            },\n",
    "            {\n",
    "                \"name\": \"call_user_proxy_agent\",\n",
    "                \"description\": \"Call user proxy agent that acts as the user and gives feedback\",\n",
    "            },\n",
    "            {\n",
    "                \"name\": \"return_final_answer\",\n",
    "                \"description\": \"Return to user with final answer when sufficient information has been gathered\",\n",
    "            },\n",
    "        ]\n",
    "    }\n",
    "\n",
    "    orchestration_response = client.models.generate_content(\n",
    "        model=FLASH_MODEL,\n",
    "        contents=orchestration_prompt,\n",
    "        config=types.GenerateContentConfig(tools=[orchestrator_tools]),\n",
    "    )\n",
    "\n",
    "    # Fall back to producing a final answer when the model replies with plain\n",
    "    # text instead of selecting a tool.\n",
    "    if orchestration_response.candidates[0].content.parts[0].function_call:\n",
    "        function_call = orchestration_response.candidates[0].content.parts[0].function_call\n",
    "        return function_call.name\n",
    "    else:\n",
    "        return \"return_final_answer\"  # Default to returning final answer if no tool called\n",
    "\n",
    "\n",
    "@tracer.chain()\n",
    "def execute_agent_call(function_name, user_query, context):\n",
    "    \"\"\"\n",
    "    Executes the specified agent call and returns the response and agent type.\n",
    "    Args:\n",
    "        function_name: The name of the function to call\n",
    "        user_query: The initial user query\n",
    "        context: Current accumulated context\n",
    "    Returns:\n",
    "        Tuple of (agent_response, agent_type)\n",
    "    \"\"\"\n",
    "    # Dispatch to the requested specialist with guard clauses; each returns the\n",
    "    # agent's text alongside a human-readable label used in the shared context.\n",
    "    if function_name == \"call_flight_planning_agent\":\n",
    "        return call_flight_planning_agent(user_query, context), \"Flight Planning\"\n",
    "    if function_name == \"call_hotel_recommendation_agent\":\n",
    "        return call_hotel_recommendation_agent(user_query, context), \"Hotel Recommendation\"\n",
    "    if function_name == \"call_travel_attraction_agent\":\n",
    "        return call_travel_attraction_agent(user_query, context), \"Travel Attraction\"\n",
    "    if function_name == \"call_user_proxy_agent\":\n",
    "        return call_user_proxy_agent(user_query, context), \"User Proxy\"\n",
    "\n",
    "    # Unrecognized tool name: contribute nothing to the context.\n",
    "    return \"\", \"Unknown\"\n",
    "\n",
    "\n",
    "@tracer.chain()\n",
    "def generate_final_answer(user_query, context, max_cycles_reached=False):\n",
    "    \"\"\"\n",
    "    Generates a final answer based on the accumulated context.\n",
    "    Args:\n",
    "        user_query: The initial user query\n",
    "        context: Current accumulated context\n",
    "        max_cycles_reached: Whether the maximum cycles were reached\n",
    "    Returns:\n",
    "        Final response to the user\n",
    "    \"\"\"\n",
    "    # Choose the closing instruction up front so the prompt is built in one step.\n",
    "    closing = (\n",
    "        \"\\n\\nProvide a comprehensive and helpful answer, noting that we've reached our maximum processing cycles.\"\n",
    "        if max_cycles_reached\n",
    "        else \"\\n\\nProvide a comprehensive and helpful answer.\"\n",
    "    )\n",
    "\n",
    "    final_prompt = f\"\"\"Create a final response to the user query: {user_query}\n",
    "    Based on this context: {context}\n",
    "    \"\"\" + closing\n",
    "\n",
    "    final_response = client.models.generate_content(\n",
    "        model=FLASH_MODEL,\n",
    "        contents=final_prompt,\n",
    "    )\n",
    "\n",
    "    return final_response.text.strip()\n",
    "\n",
    "\n",
    "@tracer.agent()\n",
    "def orchestrator(user_query, max_cycles=3):\n",
    "    \"\"\"\n",
    "    Orchestrator that decides which agent to call at each step of the process.\n",
    "    Args:\n",
    "        user_query: The initial user query\n",
    "        max_cycles: Maximum number of agent calls before returning to user\n",
    "    Returns:\n",
    "        Final response to the user\n",
    "    \"\"\"\n",
    "    context = \"\"\n",
    "\n",
    "    for cycle in range(max_cycles):\n",
    "        # Ask the orchestration model which tool, if any, to run next.\n",
    "        next_action = determine_next_step(user_query, context, cycle, max_cycles)\n",
    "\n",
    "        if next_action == \"return_final_answer\":\n",
    "            return generate_final_answer(user_query, context)\n",
    "\n",
    "        # Run the chosen specialist and fold its output into the shared context.\n",
    "        agent_response, agent_type = execute_agent_call(next_action, user_query, context)\n",
    "        context += f\"\\n\\n{agent_type} Agent Output:\\n{agent_response}\"\n",
    "\n",
    "    # Agent-call budget exhausted: answer with whatever has been gathered.\n",
    "    return generate_final_answer(user_query, context, max_cycles_reached=True)\n",
    "\n",
    "\n",
    "# Example usage: run the full orchestration loop end to end. Each decorated\n",
    "# step above is traced to the Phoenix project configured earlier.\n",
    "user_query = \"\"\"I want to plan a 5-day trip to Paris, France, sometime in October. I'm interested\n",
    "in museums and good food. Find flight options from SFO, suggest mid-range hotels near the city center,\n",
    "and recommend some relevant activities.\"\"\"\n",
    "\n",
    "response = orchestrator(user_query)\n",
    "print(response)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
